{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. MoE层里头到底有什么玩意"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "from moe import MoE\n",
    "import torch as t\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "torch版本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2.1.0+cu121'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "moe_layer = MoE(input_size=114,output_size=514,num_experts=4,hidden_size=191,k=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MoE(\n",
       "  (experts): ModuleList(\n",
       "    (0-3): 4 x MLP(\n",
       "      (fc1): Linear(in_features=114, out_features=191, bias=True)\n",
       "      (fc2): Linear(in_features=191, out_features=514, bias=True)\n",
       "      (relu): ReLU()\n",
       "      (soft): Softmax(dim=1)\n",
       "    )\n",
       "  )\n",
       "  (softplus): Softplus(beta=1, threshold=20)\n",
       "  (softmax): Softmax(dim=1)\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "moe_layer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Forward 一次试试, forward 返回的是一个Tuple，一个是神经网络的输出，一个是Gate的Loss，这个Loss鼓励每个Expert被平均使用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "output,loss = moe_layer.forward(t.randn(1,114))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 514])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0490, grad_fn=<MulBackward0>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Param Name: w_gate, Shape: torch.Size([114, 4])\n",
      "Parameter containing:\n",
      "tensor([[0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0.]], requires_grad=True)\n",
      "Param Name: w_noise, Shape: torch.Size([114, 4])\n",
      "Param Name: experts.0.fc1.weight, Shape: torch.Size([191, 114])\n",
      "Param Name: experts.0.fc1.bias, Shape: torch.Size([191])\n",
      "Param Name: experts.0.fc2.weight, Shape: torch.Size([514, 191])\n",
      "Param Name: experts.0.fc2.bias, Shape: torch.Size([514])\n",
      "Param Name: experts.1.fc1.weight, Shape: torch.Size([191, 114])\n",
      "Param Name: experts.1.fc1.bias, Shape: torch.Size([191])\n",
      "Param Name: experts.1.fc2.weight, Shape: torch.Size([514, 191])\n",
      "Param Name: experts.1.fc2.bias, Shape: torch.Size([514])\n",
      "Param Name: experts.2.fc1.weight, Shape: torch.Size([191, 114])\n",
      "Param Name: experts.2.fc1.bias, Shape: torch.Size([191])\n",
      "Param Name: experts.2.fc2.weight, Shape: torch.Size([514, 191])\n",
      "Param Name: experts.2.fc2.bias, Shape: torch.Size([514])\n",
      "Param Name: experts.3.fc1.weight, Shape: torch.Size([191, 114])\n",
      "Param Name: experts.3.fc1.bias, Shape: torch.Size([191])\n",
      "Param Name: experts.3.fc2.weight, Shape: torch.Size([514, 191])\n",
      "Param Name: experts.3.fc2.bias, Shape: torch.Size([514])\n"
     ]
    }
   ],
   "source": [
    "for name,param in moe_layer.named_parameters():\n",
    "    print(\"Param Name: {}, Shape: {}\".format(name,param.shape))\n",
    "    if 'gate' in name:\n",
    "        print(param.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "确实是有4个Expert，然后每个Expert有两个Linear层\n",
    "\n",
    "另外两个参数（`w_gate`和`w_noise`）是Gate的参数，shape都是[114,4]，用来给每个输入算出4个Expert的权重\n",
    "\n",
    "总的来说，从参数上看，就是一个神经网络复制了4份，然后使用Gate分配input到不同的Expert上\n",
    "\n",
    "Gate的参数开销可以忽略"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. MoE 的具体计算流程\n",
    "\n",
    "这个是MoE层的forward实现\n",
    "\n",
    "```python\n",
    "def forward(self, x, loss_coef=1e-2):\n",
    "        \"\"\"Args:\n",
    "        x: tensor shape [batch_size, input_size]\n",
    "        train: a boolean scalar.\n",
    "        loss_coef: a scalar - multiplier on load-balancing losses\n",
    "\n",
    "        Returns:\n",
    "        y: a tensor with shape [batch_size, output_size].\n",
    "        extra_training_loss: a scalar.  This should be added into the overall\n",
    "        training loss of the model.  The backpropagation of this loss\n",
    "        encourages all experts to be approximately equally used across a batch.\n",
    "        \"\"\"\n",
    "        gates, load = self.noisy_top_k_gating(x, self.training) \n",
    "        # calculate importance loss\n",
    "        importance = gates.sum(0)\n",
    "        #\n",
    "        loss = self.cv_squared(importance) + self.cv_squared(load)\n",
    "        loss *= loss_coef\n",
    "\n",
    "        dispatcher = SparseDispatcher(self.num_experts, gates)\n",
    "        expert_inputs = dispatcher.dispatch(x)\n",
    "        gates = dispatcher.expert_to_gates()\n",
    "        expert_outputs = [self.experts[i](expert_inputs[i]) for i in range(self.num_experts)]\n",
    "        y = dispatcher.combine(expert_outputs)\n",
    "        return y, loss\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "手动算一下看看它输出什么玩意\n",
    "\n",
    "1. 计算expert的权重\n",
    "\n",
    "`gates, load = self.noisy_top_k_gating(x, self.training) `"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = t.randn((64,114)) # input batch 64 dim 114\n",
    "gates, load =moe_layer.noisy_top_k_gating(x, moe_layer.training) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 4])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gates.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "很显然gates输出的是每个数据样本应该被分配到哪个expert上"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 0., 0., 0.],\n",
       "        [0., 0., 0., 1.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 0., 1.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [0., 0., 0., 1.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 0., 1.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [0., 0., 0., 1.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [0., 0., 0., 1.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 0., 1.],\n",
       "        [0., 0., 0., 1.],\n",
       "        [0., 0., 0., 1.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [0., 0., 0., 1.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 0., 1.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [1., 0., 0., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 0., 1.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [0., 1., 0., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 0., 1.],\n",
       "        [0., 1., 0., 0.]], grad_fn=<ScatterBackward0>)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], grad_fn=<SumBackward1>)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gates.sum(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "沿第1个维度求和（维度从0开始计数，所以这是大小为4的Expert维度，不是大小为64的batch维度），可以看到每个样本在所有Expert上的权重之和是1。猜测这些权重是用来做weighted average或者其它用途的，很显然是手工Normalize过了。\n",
    "\n",
    "手动调整MoE初始化中的$k$，可以看到不一样的结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "load.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([17.0013, 16.1062, 16.3742, 15.0322], grad_fn=<SumBackward1>)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "load"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "load大致就是每个Expert所负责的Token（也就是x）的数量，所以叫做负载。注意这里的值不是整数，和后面实际分到每个Expert的样本数（18、15、18、13）也不完全相等，所以它看起来是一个平滑的估计值，而不是精确的计数"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. 计算Gate loss\n",
    "\n",
    "```python\n",
    "importance = gates.sum(0)\n",
    "loss = self.cv_squared(importance) + self.cv_squared(load)\n",
    "loss *= loss_coef\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([18., 15., 18., 13.], grad_fn=<SumBackward1>)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "importance = gates.sum(0)\n",
    "importance"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "计算每个Expert的importance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss=moe_layer.cv_squared(importance)+moe_layer.cv_squared(load)\n",
    "loss = loss*1e-2 # coefficient is 1e-2 by default"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "总的loss是importance和load各自的`cv_squared`之和，再乘上系数`loss_coef`（默认是1e-2）。可以ctrl+鼠标左键点进去`cv_squared`看看里头是什么"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0003, grad_fn=<MulBackward0>)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "总的Gate loss目标就是均匀使用每个Expert，在这个代码中是这样，但在其它库中是否是这样我不太清楚"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3. 根据Gate的结果分配Token\n",
    "```python\n",
    "dispatcher = SparseDispatcher(self.num_experts, gates)\n",
    "expert_inputs = dispatcher.dispatch(x)\n",
    "gates = dispatcher.expert_to_gates()\n",
    "expert_outputs = [self.experts[i](expert_inputs[i]) for i in range(self.num_experts)]\n",
    "y = dispatcher.combine(expert_outputs)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from moe import SparseDispatcher\n",
    "#初始化Dispatcher\n",
    "dispatcher = SparseDispatcher(moe_layer,gates)\n",
    "expert_inputs = dispatcher.dispatch(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tuple"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(expert_inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The input shape for 0-th expert: torch.Size([18, 114])\n",
      "The input shape for 1-th expert: torch.Size([15, 114])\n",
      "The input shape for 2-th expert: torch.Size([18, 114])\n",
      "The input shape for 3-th expert: torch.Size([13, 114])\n"
     ]
    }
   ],
   "source": [
    "for idx,ei in enumerate(expert_inputs):\n",
    "    print(\"The input shape for {}-th expert: {}\".format(idx,ei.shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们可以看到，总体的和是64，这些token都被分给不同的Expert（根据Gate结果）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "gates = dispatcher.expert_to_gates()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.]], grad_fn=<SplitWithSizesBackward0>),\n",
       " tensor([[1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.]], grad_fn=<SplitWithSizesBackward0>),\n",
       " tensor([[1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.]], grad_fn=<SplitWithSizesBackward0>),\n",
       " tensor([[1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.]], grad_fn=<SplitWithSizesBackward0>))"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gates"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这个是每个数据样本在Gate中的权重，按Expert分组；如果k不等于1的话会有不同的结果。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "expert_outputs = [moe_layer.experts[i](expert_inputs[i]) for i in range(moe_layer.num_experts)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这一句话就是把分配给每个Expert的Token丢给对应的Expert，让它们算出结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The output shape for 0-th expert: torch.Size([18, 514])\n",
      "The output shape for 1-th expert: torch.Size([15, 514])\n",
      "The output shape for 2-th expert: torch.Size([18, 514])\n",
      "The output shape for 3-th expert: torch.Size([13, 514])\n"
     ]
    }
   ],
   "source": [
    "for idx,eo in enumerate(expert_outputs):\n",
    "    print(\"The output shape for {}-th expert: {}\".format(idx,eo.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = dispatcher.combine(expert_outputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后我们将这些输出结合起来，注意k=1和k>1是不一样的表现。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 514])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们可以看到获得的输出就是我们想要的shape：[64, 514]，也就是batch大小 × 单个Expert的输出维度\n",
    "\n",
    "其它问题：\n",
    "1. 在fastmoe中，是把每个Linear层都替换成MoELinear层；但在这个例子中，是把每个MLP层（包含2个Linear层）替换成MoEMLP层\n",
    "2. **关于K：我们可以通过调整K来控制一次性启用多少个Expert，这个跟我们的一些点子是冲突的（尤其是三原色）**"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
